import tensorflow as tf
import gradio as gr
import numpy as np
from PIL import Image

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(_, _), (X_test, y_test) = mnist.load_data()

# 加载保存的模型
try:
    model = tf.keras.models.load_model('best_model.h5')
except FileNotFoundError:
    print("无法找到保存的模型文件")
    exit()
except Exception as e:
    print("加载模型时出错:", str(e))
    exit()

# 定义预处理函数
def preprocess(image):
    image = Image.fromarray(image)  # 将数组转换为图像对象
    image = image.resize((28, 28)).convert('L')  # 调整图像大小并转换为灰度图像
    image_array = np.array(image)  # 将图像转换为NumPy数组
    normalized_image = image_array / 255.0  # 对图像像素进行归一化
    reshaped_image = normalized_image.reshape((1, 28, 28, 1))  # 调整图像形状以适应模型的输入
    return reshaped_image

# 定义预测函数
def predict(image):
    preprocessed_image = preprocess(image)  # 预处理输入图像
    predicted_digit = np.argmax(model.predict(preprocessed_image))  # 使用模型进行预测
    return str(predicted_digit)

# 创建Gradio界面
iface = gr.Interface(fn=predict, inputs='sketchpad', outputs='label')

# 启动Gradio界面
iface.launch()